from Functions import *
from pandas import read_csv
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, median_absolute_error, r2_score

# Apply the Matplotlib style sheet "Figures.mplstyle" to every figure this script draws.
plt.style.use("Figures.mplstyle")

#----------------------------------------------------------------------------------------------------------------------------------
#                             Load and Filter the Data
#----------------------------------------------------------------------------------------------------------------------------------

data = read_csv("Method Test Energies.csv")
# Method-name prefixes; the fitting loop builds column names such as "<method> ΔE_PCET" from them.
methods = ["O3LYP", "B3LYP", "M06L", "ZORA"]

# Column of `data` that every model is fitted to.
yval = "Barrier"

# Keep rows whose Index is non-negative (a missing/NaN Index also fails `>= 0`)
# and is not listed in `exclude`.
exclude = []
keepRows = (data["Index"] >= 0) & ~data["Index"].isin(exclude)
data = data[keepRows]

# Training set: rows flagged Training == 1, minus any withheld by Index or by Metal.
withheldIndices = []
withheldMetals = []
toTrain = (data["Training"] == 1) & ~data["Index"].isin(withheldIndices)
if withheldMetals:
	# Only read the Metal column when something is actually withheld, so a CSV
	# without that column still loads when the list is empty.
	toTrain = toTrain & ~data["Metal"].isin(withheldMetals)

training = data[toTrain]

# Test set: a fixed list of Index values.
# NOTE(review): this mask does not consult `toTrain`, so a listed row that is also
# flagged Training == 1 ends up in both `training` and `testing` — confirm intended.
testIndices = [11, 17, 23, 28, 29]
toTest = data["Index"].isin(testIndices)
testing = data[toTest]

# Per-model results, one entry per fitted model, in fitting order.
scores = []     # R^2 of the fitted model on the training set
errs = []       # mean squared error of the fitted model on the training set
looScores = []  # R^2 of `looPredictions` (presumably leave-one-out predictions)
looErrs = []    # mean squared error of `looPredictions`
CVErrs = []     # mean of the `cvErr` values returned by fitAndEvaluate
# NOTE(review): despite its name, `Fs` stores the `pval` returned by fitAndEvaluate,
# and that raw value is what the "Significant?" column prints — confirm intended.
Fs = []
rowLabels = []  # summary-table label: the method name on its first model, blank after

# Width of the label column in the summary table (the longest method name).
maxShortName = max((len(method) for method in methods), default=0)

for method in methods:
	# Two models per method: the PCET energy alone, then PCET together with PT and ET.
	# Each entry is (feature columns, numFTest passed to fitAndEvaluate, plot title).
	variants = [
		([method+" ΔE_PCET"], 1, method+" PCET Only"),
		([method+" ΔE_PCET", method+" ΔE_PT", method+" ΔE_ET"], 2, method+" PCET, PT, and ET"),
	]
	for variantNum, (xs, numFTest, title) in enumerate(variants):
		model, F, pval, tErr, cvErr, looPredictions = fitAndEvaluate(xs, numFTest, training, yval)
		plotModel(xs, model, title, training, test=testing, yVal=yval,  # figName=title,
			xTicksMajor=ticksMajorOxo, xTicksMinor=ticksMinorOxo, yTicksMajor=ticksMajorOxo, yTicksMinor=ticksMinorOxo)

		# Predict once and reuse it for both training-set metrics.
		actual = training[yval]
		fitted = model.predict(training[xs])
		scores.append(r2_score(actual, fitted))
		errs.append(mean_squared_error(actual, fitted))
		looScores.append(r2_score(actual, looPredictions))
		looErrs.append(mean_squared_error(actual, looPredictions))
		CVErrs.append(cvErr.mean())
		Fs.append(pval)
		rowLabels.append(method if variantNum == 0 else '')

print("             SUMMARY OF RESULTS            ")
print("\n"+(' '*maxShortName)+" \tR^2\tErr\tLOO R^2\tLOO Err\tCV Err\tSignificant?")
for label, score, err, looScore, looErr, cvErrMean, significance in zip(
		rowLabels, scores, errs, looScores, looErrs, CVErrs, Fs):
	# rjust left-pads the label to the column width and never truncates a longer one.
	print(f"{label.rjust(maxShortName)}:\t{score:.2f}\t{err:.2f}\t"
		f"{looScore:.2f}\t{looErr:.2f}\t{cvErrMean:.2f}\t" + str(significance))

# Wait for Enter before exiting (presumably to keep the console window open).
try:
	input()
except EOFError:
	# stdin is closed or non-interactive (e.g. a piped/batch run): nothing to wait for.
	pass